week 8: multilevel models

multilevel adventures

slopes

Let’s start by simulating the cafe data.

# ---- set population-level parameters -----
a <- 3.5       # average morning wait time
b <- (-1)      # average difference in afternoon wait time
sigma_a <- 1   # std dev in intercepts
sigma_b <- 0.5 # std dev in slopes
rho <- (-0.7)  # correlation between intercepts and slopes

# ---- create vector of means ----
Mu <- c(a, b)

# ---- create matrix of variances and covariances ----
sigmas <- c(sigma_a,sigma_b) # standard deviations
Rho <- matrix( c(1,rho,rho,1) , nrow=2 ) # correlation matrix
# now matrix multiply to get covariance matrix
Sigma <- diag(sigmas) %*% Rho %*% diag(sigmas)

# ---- simulate intercepts and slopes -----
N_cafes <- 20
library(MASS)
set.seed(5)
vary_effects <- mvrnorm( n=N_cafes, mu = Mu, Sigma=Sigma)
a_cafe <- vary_effects[, 1]
b_cafe <- vary_effects[, 2]
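As a quick check (a sketch I'm adding, not part of the original script), the off-diagonal of `Sigma` should equal `rho * sigma_a * sigma_b`, and with many simulated cafes the empirical correlation of the varying effects should approach `rho`:

```r
# sketch: verify the covariance construction and the simulated correlation
library(MASS)
sigma_a <- 1; sigma_b <- 0.5; rho <- -0.7
Sigma <- diag(c(sigma_a, sigma_b)) %*%
  matrix(c(1, rho, rho, 1), nrow = 2) %*%
  diag(c(sigma_a, sigma_b))
Sigma[1, 2]  # off-diagonal covariance: rho * sigma_a * sigma_b

set.seed(5)
ve <- mvrnorm(n = 1e5, mu = c(3.5, -1), Sigma = Sigma)
cor(ve[, 1], ve[, 2])  # close to rho with this many simulated cafes
```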

# ---- simulate observations -----

set.seed(22)
N_visits <- 10
afternoon <- rep(0:1,N_visits*N_cafes/2)
cafe_id <- rep( 1:N_cafes , each=N_visits )
mu <- a_cafe[cafe_id] + b_cafe[cafe_id]*afternoon
sigma <- 0.5 # std dev within cafes
wait <- rnorm( N_visits*N_cafes , mu , sigma )
d <- data.frame( cafe=cafe_id , afternoon=afternoon , wait=wait )
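Before modeling, it's worth confirming the design came out as intended. A minimal sanity check (my addition; the index vectors are rebuilt so the chunk runs on its own):

```r
# sketch: each of the 20 cafes gets 10 visits, alternating morning/afternoon
N_visits <- 10
N_cafes <- 20
afternoon <- rep(0:1, N_visits * N_cafes / 2)
cafe_id <- rep(1:N_cafes, each = N_visits)
length(afternoon)                # total number of observations
sum(afternoon)                   # half of the visits are afternoon
all(table(cafe_id) == N_visits)  # every cafe observed 10 times
```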

a simulation note from RM

In this exercise, we are simulating data from a generative process and then analyzing those data with a model that reflects exactly the correct structure of that process. But in the real world, we’re never so lucky. Instead we are always forced to analyze data with a model that is MISSPECIFIED: the true data-generating process is different from the model. Simulation can, however, be used to explore misspecification. Just simulate data from a process and then see how a number of models, none of which match the data-generating process exactly, perform. And always remember that Bayesian inference does not depend upon data-generating assumptions, such as the likelihood, being true. Non-Bayesian approaches may depend upon sampling distributions for their inferences, but this is not the case for a Bayesian model. In a Bayesian model, a likelihood is a prior for the data, and inference about parameters can be surprisingly insensitive to its details.

Mathematical model:

likelihood function and linear model

\[\begin{align*} W_i &\sim \text{Normal}(\mu_i, \sigma) \\ \mu_i &= \alpha_{CAFE[i]} + \beta_{CAFE[i]}A_i \end{align*}\]

varying intercepts and slopes

\[\begin{align*} \begin{bmatrix} \alpha_{CAFE[i]} \\ \beta_{CAFE[i]} \end{bmatrix} &\sim \text{MVNormal}( \begin{bmatrix} \alpha \\ \beta \end{bmatrix}, \mathbf{S}) \\ \mathbf{S} &= \begin{pmatrix} \sigma_{\alpha} & 0 \\ 0 & \sigma_{\beta}\end{pmatrix}\mathbf{R}\begin{pmatrix} \sigma_{\alpha} & 0 \\ 0 & \sigma_{\beta}\end{pmatrix} \\ \end{align*}\]

priors

\[\begin{align*} \alpha &\sim \text{Normal}(5,2) \\ \beta &\sim \text{Normal}(-1,0.5) \\ \sigma &\sim \text{Exponential}(1) \\ \sigma_{\alpha} &\sim \text{Exponential}(1) \\ \sigma_{\beta} &\sim \text{Exponential}(1) \\ \mathbf{R} &\sim \text{LKJcorr}(2) \end{align*}\]
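A quick prior predictive sketch (my addition, using only the priors above) shows what these choices imply for average wait times before seeing any data:

```r
# sketch: draws from the priors for the population-level wait times
set.seed(1)
alpha <- rnorm(1e4, 5, 2)     # alpha ~ Normal(5, 2): average morning wait
beta  <- rnorm(1e4, -1, 0.5)  # beta ~ Normal(-1, 0.5): afternoon difference
quantile(alpha + beta, c(0.05, 0.95))  # implied average afternoon waits
```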

LKJ correlation prior

Code
# examples
rlkj_1 = rethinking::rlkjcorr(1e4, K=2, eta=1)
rlkj_2 = rethinking::rlkjcorr(1e4, K=2, eta=2)
rlkj_4 = rethinking::rlkjcorr(1e4, K=2, eta=4)
rlkj_6 = rethinking::rlkjcorr(1e4, K=2, eta=6)
data.frame(rlkj_1= rlkj_1[,1,2], 
           rlkj_2= rlkj_2[,1,2], 
           rlkj_4= rlkj_4[,1,2],
           rlkj_6= rlkj_6[,1,2]) %>% 
  ggplot() +
  geom_density(aes(x=rlkj_1, color = "1"), alpha=.3) +
  geom_density(aes(x=rlkj_2, color = "2"), alpha=.3) +
  geom_density(aes(x=rlkj_4, color = "4"), alpha=.3) +
  geom_density(aes(x=rlkj_6, color = "6"), alpha=.3) +
  labs(x="R", color="eta") +
  theme(legend.position = "top")
m3 <- brm(
  data = d,
  family = gaussian,
  wait ~ 1 + afternoon + (1 + afternoon | cafe),
  prior = c(
    prior( normal(5,2),    class=Intercept ), 
    prior( normal(-1, .5), class=b),
    prior( exponential(1), class=sd),
    prior( exponential(1), class=sigma),
    prior( lkj(2),         class=cor)
  ), 
  iter=2000, warmup=1000, chains=4, cores=4, seed=9,
  file=here("files/models/81.3")
)
posterior_summary(m3)
                                    Estimate  Est.Error          Q2.5
b_Intercept                       3.66309993 0.22512829  3.213083e+00
b_afternoon                      -1.13063289 0.14120596 -1.408561e+00
sd_cafe__Intercept                0.96551600 0.16803123  7.066187e-01
sd_cafe__afternoon                0.59051273 0.12273539  3.874848e-01
cor_cafe__Intercept__afternoon   -0.50623999 0.18222783 -7.985059e-01
sigma                             0.47358446 0.02805115  4.236348e-01
Intercept                         3.09778349 0.19979832  2.705237e+00
r_cafe[1,Intercept]               0.55174600 0.29347667 -9.423002e-03
r_cafe[2,Intercept]              -1.50472165 0.30071773 -2.104292e+00
r_cafe[3,Intercept]               0.70413834 0.29853566  1.176037e-01
r_cafe[4,Intercept]              -0.41979485 0.29178997 -9.814426e-01
r_cafe[5,Intercept]              -1.78797293 0.29822610 -2.362022e+00
r_cafe[6,Intercept]               0.59663781 0.29359765  2.245567e-02
r_cafe[7,Intercept]              -0.04641806 0.29482124 -6.354629e-01
r_cafe[8,Intercept]               0.28169156 0.29935233 -3.081673e-01
r_cafe[9,Intercept]               0.31612410 0.29434697 -2.805685e-01
r_cafe[10,Intercept]             -0.10003272 0.29369089 -6.829916e-01
r_cafe[11,Intercept]             -1.73598803 0.29010317 -2.309541e+00
r_cafe[12,Intercept]              0.17830850 0.29047523 -3.977328e-01
r_cafe[13,Intercept]              0.21754651 0.29824631 -3.748931e-01
r_cafe[14,Intercept]             -0.48794486 0.29898931 -1.083548e+00
r_cafe[15,Intercept]              0.79224377 0.30177609  2.165891e-01
r_cafe[16,Intercept]             -0.27285927 0.29415940 -8.591375e-01
r_cafe[17,Intercept]              0.55395308 0.29658903 -2.563127e-02
r_cafe[18,Intercept]              2.08452541 0.29755910  1.515073e+00
r_cafe[19,Intercept]             -0.41612887 0.29589644 -1.000933e+00
r_cafe[20,Intercept]              0.06636752 0.29841240 -5.261121e-01
r_cafe[1,afternoon]              -0.02570753 0.28881915 -5.997060e-01
r_cafe[2,afternoon]               0.22656232 0.29510159 -3.598793e-01
r_cafe[3,afternoon]              -0.79745302 0.29743733 -1.391656e+00
r_cafe[4,afternoon]              -0.10154505 0.28200814 -6.546762e-01
r_cafe[5,afternoon]               0.99519559 0.29618891  4.155578e-01
r_cafe[6,afternoon]              -0.16375475 0.28185500 -7.209486e-01
r_cafe[7,afternoon]               0.10326646 0.28112073 -4.649813e-01
r_cafe[8,afternoon]              -0.50195539 0.29310739 -1.088135e+00
r_cafe[9,afternoon]              -0.17475020 0.28373869 -7.312723e-01
r_cafe[10,afternoon]              0.17791484 0.29069257 -3.867201e-01
r_cafe[11,afternoon]              0.70056987 0.28544648  1.528820e-01
r_cafe[12,afternoon]             -0.05544267 0.28584966 -6.327759e-01
r_cafe[13,afternoon]             -0.67808867 0.29182387 -1.248284e+00
r_cafe[14,afternoon]              0.19279510 0.28668008 -3.613501e-01
r_cafe[15,afternoon]             -1.06138413 0.30950531 -1.649146e+00
r_cafe[16,afternoon]              0.09023589 0.28321804 -4.699439e-01
r_cafe[17,afternoon]             -0.08881056 0.28260403 -6.366611e-01
r_cafe[18,afternoon]              0.10274907 0.30214813 -4.871265e-01
r_cafe[19,afternoon]              0.87094086 0.29872493  2.938193e-01
r_cafe[20,afternoon]              0.07455140 0.29312248 -4.916865e-01
lprior                           -5.06080142 0.43581160 -6.072327e+00
lp__                           -197.19938164 7.15820084 -2.120673e+02
                                       Q97.5
b_Intercept                       4.10670135
b_afternoon                      -0.84752548
sd_cafe__Intercept                1.36573636
sd_cafe__afternoon                0.86860280
cor_cafe__Intercept__afternoon   -0.09558939
sigma                             0.53332102
Intercept                         3.49665640
r_cafe[1,Intercept]               1.13941891
r_cafe[2,Intercept]              -0.89622321
r_cafe[3,Intercept]               1.30047421
r_cafe[4,Intercept]               0.14750442
r_cafe[5,Intercept]              -1.20089749
r_cafe[6,Intercept]               1.17820219
r_cafe[7,Intercept]               0.54793940
r_cafe[8,Intercept]               0.86131427
r_cafe[9,Intercept]               0.90304163
r_cafe[10,Intercept]              0.49032131
r_cafe[11,Intercept]             -1.17286774
r_cafe[12,Intercept]              0.75797210
r_cafe[13,Intercept]              0.82197923
r_cafe[14,Intercept]              0.09498970
r_cafe[15,Intercept]              1.39352140
r_cafe[16,Intercept]              0.31579626
r_cafe[17,Intercept]              1.15364945
r_cafe[18,Intercept]              2.67599091
r_cafe[19,Intercept]              0.15617338
r_cafe[20,Intercept]              0.64711982
r_cafe[1,afternoon]               0.53465861
r_cafe[2,afternoon]               0.79151024
r_cafe[3,afternoon]              -0.23584882
r_cafe[4,afternoon]               0.42989679
r_cafe[5,afternoon]               1.57150918
r_cafe[6,afternoon]               0.37630537
r_cafe[7,afternoon]               0.64220947
r_cafe[8,afternoon]               0.07470165
r_cafe[9,afternoon]               0.37594150
r_cafe[10,afternoon]              0.75426810
r_cafe[11,afternoon]              1.27372879
r_cafe[12,afternoon]              0.49754018
r_cafe[13,afternoon]             -0.12152580
r_cafe[14,afternoon]              0.78474839
r_cafe[15,afternoon]             -0.43672292
r_cafe[16,afternoon]              0.63835641
r_cafe[17,afternoon]              0.46710868
r_cafe[18,afternoon]              0.70327918
r_cafe[19,afternoon]              1.47137045
r_cafe[20,afternoon]              0.66350230
lprior                           -4.35587362
lp__                           -184.27360638

Let’s get the slopes and intercepts for each cafe.

Code
intercepts = coef(m3)$cafe[ ,, "Intercept"]
slopes = coef(m3)$cafe[,, "afternoon"]
cafe_params = data.frame(
  cafe=1:20,
  intercepts=intercepts[, 1],
  slopes=slopes[, 1]
) 
cafe_params
   cafe intercepts     slopes
1     1   4.214846 -1.1563404
2     2   2.158378 -0.9040706
3     3   4.367238 -1.9280859
4     4   3.243305 -1.2321779
5     5   1.875127 -0.1354373
6     6   4.259738 -1.2943876
7     7   3.616682 -1.0273664
8     8   3.944791 -1.6325883
9     9   3.979224 -1.3053831
10   10   3.563067 -0.9527181
11   11   1.927112 -0.4300630
12   12   3.841408 -1.1860756
13   13   3.880646 -1.8087216
14   14   3.175155 -0.9378378
15   15   4.455344 -2.1920170
16   16   3.390241 -1.0403970
17   17   4.217053 -1.2194435
18   18   5.747625 -1.0278838
19   19   3.246971 -0.2596920
20   20   3.729467 -1.0560815
Code
cafe_params %>% 
  ggplot( aes(x=intercepts, y=slopes) ) +
  geom_point(size = 2) 
Code
cafe_params %>% 
  ggplot( aes(x=intercepts, y=slopes) ) +
  stat_ellipse() +
  geom_point(size = 2) 
Code
cafe_params %>% 
  ggplot( aes(x=intercepts, y=slopes) ) +
  mapply(function(level) {
    stat_ellipse(geom  = "polygon", type = "norm",
                 linewidth = 0, alpha = .1, fill = "#1c5253",
                 level = level)
    }, 
    # enter the levels here
    level = c(1:9 / 10, .99)) +
  geom_point(size = 2) 

More about stat_ellipse in the ggplot2 documentation.

exercise

Now use the slopes and intercepts to calculate the expected morning and afternoon wait times for each cafe. Plot these as a scatterplot. Bonus for ellipses.

Code
cafe_params %>% 
  mutate(
    morning = intercepts, 
    afternoon = intercepts + slopes
  ) %>% 
  ggplot( aes(x=morning, y=afternoon) ) +
  mapply(function(level) {
    stat_ellipse(geom  = "polygon", type = "norm",
                 linewidth = 0, alpha = .1, fill = "#1c5253",
                 level = level)
    }, 
    # enter the levels here
    level = c(1:9 / 10, .99)) +
  geom_point(size = 2)+
  labs(x="morning wait time",
       y="afternoon wait time")

What is the correlation between our intercepts and slopes?

Code
post = as_draws_df(m3)
rlkj_2 = rethinking::rlkjcorr(nrow(post), K=2, eta=2)

data.frame(prior= rlkj_2[,1,2],
           posterior = post$cor_cafe__Intercept__afternoon) %>% 
  ggplot() +
  geom_density(aes(x=prior, color = "prior"), alpha=.3) +
  geom_density(aes(x=posterior, color = "posterior"), alpha=.3) +
  labs(x="R") +
  theme(legend.position = "top")

advanced varying slopes

Let’s return to the chimp experiment example. As a reminder, our primary outcome is whether the chimpanzee pulls the LEFT lever (binary). In the data, there are multiple clusters:

  • actor: chimps undergo multiple trials in the experiment.
  • block_id: chimps participate in the experiment in different blocks, on different days.

We also want to know the effects of different features the experimenter can manipulate: whether the prosocial option is on the left or right side (prosoc_left), and whether there is another chimpanzee present or not (condition). To simplify, we’ll combine this 2x2 design into a single variable with 4 levels (treatment).
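The coding can be sketched directly: `1 + prosoc_left + 2*condition` maps the four prosoc_left-by-condition combinations onto the indices 1 through 4 (the same coding used in the model code below):

```r
# sketch: the 2x2 design collapses into a single 4-level treatment index
prosoc_left <- c(0, 1, 0, 1)
condition   <- c(0, 0, 1, 1)
treatment <- 1 + prosoc_left + 2 * condition
treatment  # 1 = right/no partner, 2 = left/no partner,
           # 3 = right/partner,    4 = left/partner
```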

We’ll fit a CROSS-CLASSIFIED VARYING SLOPES model. Fun!

\[\begin{align*} L_i &\sim \text{Binomial}(1, p_i) \\ \text{logit}(p_i) &= \gamma_{\text{TID}[i]} + \alpha_{\text{ACTOR}[i],\text{TID}[i]} + \beta_{\text{BLOCK}[i],\text{TID}[i]} \\ \begin{bmatrix} \alpha_{j,1} \\ \alpha_{j,2} \\ \alpha_{j,3} \\ \alpha_{j,4} \end{bmatrix} &\sim \text{MVNormal}\begin{pmatrix} \begin{bmatrix} 0\\0\\0\\0 \end{bmatrix}, \mathbf{S}_{\text{ACTOR}}\end{pmatrix} \\ \begin{bmatrix} \beta_{j,1} \\ \beta_{j,2} \\ \beta_{j,3} \\ \beta_{j,4} \end{bmatrix} &\sim \text{MVNormal}\begin{pmatrix} \begin{bmatrix} 0\\0\\0\\0 \end{bmatrix}, \mathbf{S}_{\text{BLOCK}}\end{pmatrix} \\ \end{align*}\]

And the rest of the model will look like our cafe model from before.

data(chimpanzees, package="rethinking")
d <- chimpanzees
d$treatment = as.factor(1 + d$prosoc_left + 2*d$condition)

m4 <- brm(
  data = d, 
  family = bernoulli(link = "logit"),
  pulled_left ~ 0 + treatment + (0 + treatment | actor) + (0 + treatment | block),
  prior = c(
    prior( normal(0, 1),   class=b),
    prior( exponential(1), class=sd),
    prior( lkj(2),         class=cor)), 
  iter=2000, warmup=1000, chains=4, cores=4, seed=9,
  file = here("files/models/81_4"))
print(m4)
 Family: bernoulli 
  Links: mu = logit 
Formula: pulled_left ~ 0 + treatment + (0 + treatment | actor) + (0 + treatment | block) 
   Data: d (Number of observations: 504) 
  Draws: 4 chains, each with iter = 2000; warmup = 1000; thin = 1;
         total post-warmup draws = 4000

Multilevel Hyperparameters:
~actor (Number of levels: 7) 
                           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
sd(treatment1)                 1.39      0.47     0.69     2.52 1.00     2211
sd(treatment2)                 0.91      0.41     0.31     1.88 1.00     3216
sd(treatment3)                 1.86      0.58     1.02     3.27 1.00     3926
sd(treatment4)                 1.57      0.59     0.73     3.01 1.00     3435
cor(treatment1,treatment2)     0.42      0.28    -0.19     0.87 1.00     3550
cor(treatment1,treatment3)     0.52      0.25    -0.04     0.90 1.00     3288
cor(treatment2,treatment3)     0.49      0.26    -0.11     0.88 1.00     3054
cor(treatment1,treatment4)     0.44      0.27    -0.15     0.86 1.00     2822
cor(treatment2,treatment4)     0.45      0.27    -0.16     0.87 1.00     3470
cor(treatment3,treatment4)     0.58      0.24     0.02     0.92 1.00     3462
                           Tail_ESS
sd(treatment1)                 2842
sd(treatment2)                 2806
sd(treatment3)                 3095
sd(treatment4)                 2981
cor(treatment1,treatment2)     2956
cor(treatment1,treatment3)     2968
cor(treatment2,treatment3)     3384
cor(treatment1,treatment4)     2560
cor(treatment2,treatment4)     3614
cor(treatment3,treatment4)     3553

~block (Number of levels: 6) 
                           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS
sd(treatment1)                 0.41      0.32     0.01     1.19 1.00     1919
sd(treatment2)                 0.44      0.34     0.02     1.27 1.00     1993
sd(treatment3)                 0.30      0.27     0.01     0.98 1.00     3499
sd(treatment4)                 0.47      0.37     0.02     1.39 1.00     2219
cor(treatment1,treatment2)    -0.06      0.37    -0.75     0.66 1.00     5144
cor(treatment1,treatment3)    -0.02      0.38    -0.73     0.70 1.00     6994
cor(treatment2,treatment3)    -0.03      0.38    -0.73     0.69 1.00     5766
cor(treatment1,treatment4)     0.05      0.38    -0.67     0.72 1.00     5567
cor(treatment2,treatment4)     0.05      0.37    -0.66     0.72 1.00     4380
cor(treatment3,treatment4)     0.02      0.38    -0.68     0.73 1.00     3058
                           Tail_ESS
sd(treatment1)                 2006
sd(treatment2)                 2184
sd(treatment3)                 2667
sd(treatment4)                 2463
cor(treatment1,treatment2)     2969
cor(treatment1,treatment3)     2497
cor(treatment2,treatment3)     3386
cor(treatment1,treatment4)     2871
cor(treatment2,treatment4)     3486
cor(treatment3,treatment4)     3195

Regression Coefficients:
           Estimate Est.Error l-95% CI u-95% CI Rhat Bulk_ESS Tail_ESS
treatment1     0.20      0.51    -0.83     1.22 1.00     2373     2673
treatment2     0.64      0.41    -0.18     1.46 1.00     3301     2923
treatment3    -0.03      0.59    -1.17     1.20 1.00     2854     2578
treatment4     0.67      0.54    -0.42     1.72 1.00     3253     3159

Draws were sampled using sampling(NUTS). For each parameter, Bulk_ESS
and Tail_ESS are effective sample size measures, and Rhat is the potential
scale reduction factor on split chains (at convergence, Rhat = 1).
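Because the model works on the logit scale, the population-level treatment coefficients are easier to read after an inverse-logit transform. A sketch using the point estimates printed above (`plogis` is R's inverse logit):

```r
# sketch: treatment point estimates from the summary, on the probability scale
b <- c(treatment1 = 0.20, treatment2 = 0.64,
       treatment3 = -0.03, treatment4 = 0.67)
round(plogis(b), 2)  # implied probabilities of pulling the left lever
```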

Let’s visualize the estimates. We’ll select just one of the 6 blocks.

Code
d$labels = factor(d$treatment, 1:4, labels=c("r/n", "l/n", "r/p", "l/p"))

nd <-
  d %>% 
  distinct(actor, condition, labels, prosoc_left, treatment) %>% 
  mutate(block = 5)

# compute and wrangle the posterior predictions
nd = fitted(m4,
       newdata = nd) %>% 
  data.frame() %>% 
  bind_cols(nd) %>% 
  # add the empirical proportions
  left_join(
    d %>%
      group_by(actor, treatment) %>%
      mutate(proportion = mean(pulled_left)) %>% 
      distinct(actor, treatment, proportion),
    by = c("actor", "treatment")
  ) %>% 
  mutate(condition = factor(condition),
         prosoc_left = factor(prosoc_left)) 

# for annotation
text <-
  distinct(d, labels) %>% 
  mutate(actor = 1,
         prop  = c(.07, .8, .08, .795))
# plot!
nd %>% 
  ggplot(aes(x = labels)) +
  geom_hline(yintercept = .5,  alpha = 1/2, linetype = 2) +
  # empirical proportions
  geom_point(aes(y = proportion, shape = condition, color = "raw"),
             fill = "white", size = 2.5, show.legend = F) + 
  # posterior predictions
  geom_line(aes(y = Estimate, group = prosoc_left, color = "model"),
            linewidth = 3/4) +
  geom_pointrange(aes(y = Estimate, ymin = Q2.5, ymax = Q97.5, shape = condition,
                      color = "model"),
                  fill = "white", fatten = 8, linewidth = 1/3, show.legend = F) + 
  # annotation for the conditions
  geom_text(data = text,
            aes(y = prop, label = labels), size = 5) +
  guides(color = "none") +
  scale_shape_manual(values = c(21, 19)) +
  scale_color_manual(values = c("#1c5253", "black")) +
  scale_x_discrete(NULL, breaks=NULL) +
  scale_y_continuous("proportion left lever", 
                     breaks = 0:2 / 2, 
                     labels = c("0", ".5", "1"))  +
  facet_wrap(~ actor, nrow = 1, labeller = label_both)